import numpy as np
from scipy import sparse
from abc import ABCMeta, abstractclassmethod
import multiprocessing
from joblib import Parallel, delayed
import sys, os, datetime, matplotlib
import matplotlib.pyplot as plt
import graphlearning as gl

class v_laplace(gl.ssl.ssl):
    """Graph-based semi-supervised learner driven by an iterated propagation step.

    Each iteration applies the degree-normalised operator ``D^-1 W^T`` to the
    current score matrix and combines it with a term built from the deviation
    of the scores from their (optionally degree-weighted) column mean, scaled
    by ``lamb``. Labelled rows are then either overwritten (hard constraint)
    or pulled towards their targets (soft constraint).

    Parameters
    ----------
    W : optional
        Forwarded unchanged to the ``gl.ssl.ssl`` constructor.
    var : str, default 'weighted'
        ``'mean'`` uses the plain column mean of the scores; any other value
        uses the degree-weighted column mean.
    mode : {'accum', 'scale'}, default 'accum'
        ``'accum'`` adds ``lamb * (u - mean)`` to the propagated scores;
        ``'scale'`` subtracts ``lamb * mean`` and divides by ``1 - lamb``.
    class_priors : optional
        Forwarded unchanged to the ``gl.ssl.ssl`` constructor.
    lamb : float, default 1.0
        Weight of the mean/deviation term. Must differ from 1 when
        ``mode='scale'``.
    gamma : float, default 0.1
        Step size of the soft label constraint (used when ``is_hard`` is False).
    is_hard : bool, default True
        If True, labelled rows are reset to their centred one-hot targets
        after every iteration.
    use_cuda : bool, default False
        If True, the iteration runs on the GPU through ``torch``.
    min_iter, max_iter : int
        Lower and upper bounds on the number of iterations.
    """

    def __init__(self, W=None, var='weighted', mode='accum', class_priors=None, lamb=1.0, gamma=0.1, is_hard=True, use_cuda=False, min_iter=50, max_iter=1000):
        super().__init__(W, class_priors)
        self.lamb = lamb
        self.gamma = gamma
        self.use_cuda = use_cuda
        self.min_iter = min_iter
        self.max_iter = max_iter
        self.mode = mode
        self.is_hard = is_hard
        self.var = var

        # Setup accuracy filename
        self.accuracy_filename = '_v_laplace'
        # Setup algorithm name
        self.name = 'v_laplace with lamb=%.4f' % lamb

    def _fit(self, train_ind, train_labels, all_labels=None):
        """Run the iteration and return the final score matrix.

        Parameters
        ----------
        train_ind : array-like of int
            Indices of the labelled nodes.
        train_labels : array-like of int
            Labels of the nodes listed in ``train_ind``.
        all_labels : array-like, optional
            If given, accuracy diagnostics are printed after every iteration
            and ``self.prob`` is updated with the current scores.

        Returns
        -------
        numpy.ndarray
            Score matrix with one row per graph node and one column per
            distinct training label.

        Raises
        ------
        ValueError
            If ``self.mode`` is not ``'scale'`` or ``'accum'``, or if
            ``mode='scale'`` is combined with ``lamb == 1``.
        """
        # An unknown mode would silently skip the propagation step entirely.
        if self.mode not in ('scale', 'accum'):
            raise ValueError("mode must be 'scale' or 'accum', got %r" % (self.mode,))
        # 'scale' divides by (1 - lamb); lamb == 1 would fill the scores with inf/nan.
        if self.mode == 'scale' and self.lamb == 1:
            raise ValueError("mode='scale' divides by (1 - lamb), so lamb must not equal 1")

        n = self.graph.num_nodes
        k = len(np.unique(train_labels))
        # Captured before train_ind may be turned into a CUDA tensor.
        num_train = len(train_ind)

        # Zero-out diagonal for faster convergence
        W = self.graph.weight_matrix
        W = W - sparse.spdiags(W.diagonal(), 0, n, n)
        G = gl.graph(W)

        # Source term: one-hot targets centred across classes
        onehot = gl.utils.labels_to_onehot(train_labels)
        onehot = onehot - onehot.mean(1, keepdims=True)
        source = np.zeros((n, k))
        source[train_ind] = onehot

        # Mask-out unlabeled data
        mask = np.zeros((n, k))
        mask[train_ind] = 1

        # Setup matrices
        D = G.degree_matrix(p=1)
        D_inv = G.degree_matrix(p=-1)
        P = D_inv * W.transpose()
        # SciPy copy of D kept for the NumPy-side diagnostic print below;
        # D itself may be replaced by a CUDA tensor.
        D_cpu = D

        # Distribution v starts uniform on the labelled nodes and is pushed
        # through RW each iteration; the loop watches its distance to the
        # degree-proportional distribution vinf.
        v = np.zeros(n)
        v[train_ind] = 1
        v = v / np.sum(v)
        deg = G.degree_vector()
        vinf = deg / np.sum(deg)
        RW = W.transpose() * D_inv

        ut = np.zeros((n, k))
        if self.use_cuda:
            import torch
            D = gl.utils.torch_sparse(D).cuda()
            P = gl.utils.torch_sparse(P).cuda()
            ut = torch.from_numpy(ut).float().cuda()
            train_ind = torch.LongTensor(train_ind).cuda()
            onehot = torch.from_numpy(onehot).float().cuda()
            source = torch.from_numpy(source).float().cuda()
            mask = torch.from_numpy(mask).float().cuda()
            deg_total = torch.sparse.sum(D)
        else:
            deg_total = D.sum()

        T = 0
        while (T < self.min_iter or np.max(np.absolute(v - vinf)) > 1/n) and (T < self.max_iter):
            # Column mean of the current scores (plain or degree-weighted).
            if self.var == 'mean':
                ut_avg = ut.mean(0)
            elif self.use_cuda:
                ut_avg = torch.sparse.mm(D, ut).sum(0) / deg_total
            else:
                ut_avg = (D * ut).sum(0) / deg_total

            propagated = torch.sparse.mm(P, ut) if self.use_cuda else P * ut
            if self.mode == 'scale':
                ut = propagated - self.lamb * ut_avg
                ut = ut / (1 - self.lamb)
            else:  # 'accum'
                ut_var = ut - ut_avg
                ut = propagated + self.lamb * ut_var

            # Label constraint: overwrite (hard) or relax towards the source (soft).
            if self.is_hard:
                ut[train_ind] = onehot
            else:
                ut += self.gamma * mask * (source - ut)

            v = RW * v
            T += 1

            # Compute accuracy if all labels are provided
            if all_labels is not None:
                u = ut.cpu().numpy() if self.use_cuda else ut
                self.prob = u
                print(self.prob.max(), self.prob.min())
                labels = self.predict()
                acc = gl.ssl.ssl_accuracy(labels, all_labels, num_train)
                print('%d,Accuracy = %.2f'%(T, acc), 'Mean of u is', (D_cpu * u).mean() / u.max())

        # Converted once here so the result is defined even if the loop never ran.
        return ut.cpu().numpy() if self.use_cuda else ut
    

class v_laplace_mbo(gl.ssl.ssl):
    """MBO-style refinement initialised from a ``v_laplace`` prediction.

    Starting from the labels predicted by an internal ``v_laplace`` model,
    the fit alternates ``Ns`` explicit diffusion steps (with a deviation term
    scaled by ``lamb`` and a label constraint) with a projection back to
    one-hot labels via ``volume_label_projection``, repeated ``T`` times.

    Parameters
    ----------
    W : optional
        Forwarded unchanged to the ``gl.ssl.ssl`` constructor and to the
        internal ``v_laplace`` model.
    class_priors : optional
        Forwarded unchanged to the ``gl.ssl.ssl`` constructor.
    mode : {'hard', 'soft'}, default 'hard'
        ``'hard'`` overwrites labelled rows with their one-hot labels after
        each diffusion step; ``'soft'`` relaxes them with step ``dt * gamma``.
    lamb : float, default 1.0
        Weight of the deviation-from-weighted-mean term.
    gamma : float, default 0.1
        Strength of the soft label constraint.
    use_cuda : bool, default False
        If True, the diffusion steps run on the GPU through ``torch``.
    min_iter, max_iter : int
        Iteration bounds handed to the internal ``v_laplace`` model.
    Ns : int, default 40
        Number of diffusion steps between projections.
    mu : default 1
        Stored on the instance; not read by ``_fit``.
    T : int, default 20
        Number of diffusion/projection rounds.
    """

    def __init__(self, W=None, class_priors=None, mode='hard', lamb=1.0, gamma=0.1, use_cuda=False, min_iter=50, max_iter=1000, Ns=40, mu=1, T=20):
        super().__init__(W, class_priors)
        # The 'hard'/'soft' choice maps onto v_laplace's is_hard flag;
        # v_laplace's own `mode` ('accum'/'scale') is left at its default.
        self.v_laplace_model = v_laplace(W, lamb=lamb, gamma=gamma, is_hard=(mode == 'hard'), use_cuda=use_cuda, min_iter=min_iter, max_iter=max_iter)
        # Alias under the attribute name the constructor originally assigned.
        self.v_lapalce_model = self.v_laplace_model

        self.Ns = Ns
        self.mu = mu
        self.mode = mode
        self.use_cuda = use_cuda
        self.T = T
        self.lamb = lamb
        self.gamma = gamma

        # Setup accuracy filename
        self.accuracy_filename = '_v_lapalce_mbo'
        # Setup algorithm name
        self.name = 'v_lapalce_mbo with lamb=%.2f' % lamb

    def _fit(self, train_ind, train_labels, all_labels=None):
        """Run the diffusion/projection rounds and return one-hot labels.

        Parameters
        ----------
        train_ind : array-like of int
            Indices of the labelled nodes.
        train_labels : array-like of int
            Labels of the nodes listed in ``train_ind``.
        all_labels : array-like, optional
            If given, it is passed on to the initialising model and the
            accuracy is printed after every projection.

        Returns
        -------
        numpy.ndarray
            One-hot encoding of the labels from the last projection (or of
            the initial ``v_laplace`` prediction when ``T`` is 0).

        Raises
        ------
        NotImplementedError
            If ``self.mode`` is neither ``'soft'`` nor ``'hard'``.
        """
        # Fail before the expensive initialisation rather than inside the loop.
        if self.mode not in ('soft', 'hard'):
            raise NotImplementedError("Only 'soft' and 'hard' modes are implemented!")
        hard = self.mode == 'hard'

        n = self.graph.num_nodes
        k = len(np.unique(train_labels))
        # Captured before train_ind may be turned into a CUDA tensor.
        num_train = len(train_ind)

        # Zero-out diagonal for faster convergence
        W = self.graph.weight_matrix
        W = W - sparse.spdiags(W.diagonal(), 0, n, n)
        G = gl.graph(W)

        # Source term
        onehot = gl.utils.labels_to_onehot(train_labels)
        source = np.zeros((n, k))
        source[train_ind] = onehot

        # Mask-out unlabeled data
        mask = np.zeros((n, k))
        mask[train_ind] = 1

        # Degree matrix used for the weighted mean of the scores
        D = G.degree_matrix(p=1)

        # Initialize via the v_laplace model
        labels = self.v_laplace_model.fit_predict(train_ind, train_labels, all_labels=all_labels)
        u = gl.utils.labels_to_onehot(labels)

        # Time step for stability
        dt = 1 / np.max(G.degree_vector())

        # Explicit diffusion step operator I - dt * L
        P = sparse.identity(n) - dt * G.laplacian()

        if self.use_cuda:
            import torch
            D = gl.utils.torch_sparse(D).cuda()
            P = gl.utils.torch_sparse(P).cuda()
            train_ind = torch.LongTensor(train_ind).cuda()
            onehot = torch.from_numpy(onehot).float().cuda()
            source = torch.from_numpy(source).float().cuda()
            mask = torch.from_numpy(mask).float().cuda()
            deg_total = torch.sparse.sum(D)
        else:
            deg_total = D.sum()

        for i in range(self.T):
            ut = torch.from_numpy(u).float().cuda() if self.use_cuda else u
            for _ in range(self.Ns):
                # Diffuse, then add the deviation from the degree-weighted mean.
                if self.use_cuda:
                    ut = torch.sparse.mm(P, ut)
                    ut_avg = torch.sparse.mm(D, ut).sum(0) / deg_total
                else:
                    ut = P * ut
                    ut_avg = (D * ut).sum(0) / deg_total
                ut += dt * self.lamb * (ut - ut_avg)

                # Label constraint
                if hard:
                    ut[train_ind] = onehot
                else:
                    ut += dt * self.gamma * mask * (source - ut)

            u = ut.cpu().numpy() if self.use_cuda else ut
            # Projection step
            self.prob = u
            labels = self.volume_label_projection()
            u = gl.utils.labels_to_onehot(labels)

            # Compute accuracy if all labels are provided
            if all_labels is not None:
                acc = gl.ssl.ssl_accuracy(labels, all_labels, num_train)
                print('%d, Accuracy = %.2f'%(i,acc))

        return u